# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
TODO(T132414938) Add a module-level docstring
"""

from __future__ import annotations

import json
import logging
import re
import subprocess
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Sequence, Set

from . import UserError
from .errors import Errors
from .filesystem import get_filesystem

LOG: Logger = logging.getLogger(__name__)


class Configuration:
    """In-memory view of a single Pyre configuration file.

    The file is either a project configuration (``.pyre_configuration``) or a
    local configuration (``.pyre_configuration.local``). The fields this class
    edits are exposed as attributes; ``get_contents`` overlays them on the
    originally parsed JSON so that ``write`` keeps any keys this class does not
    model.
    """

    def __init__(
        self, path: Path, json_contents: Optional[Dict[str, Any]] = None
    ) -> None:
        """Create a configuration for ``path``.

        Args:
            path: Location of the configuration file. Its parent directory is
                treated as the configuration root.
            json_contents: Already-parsed JSON for ``path``. When omitted, the
                file is opened and parsed here.

        Raises:
            OSError: If ``json_contents`` is omitted and ``path`` can't be read.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        if json_contents is None:
            # JSON text is UTF-8; don't depend on the platform locale encoding.
            with open(path, "r", encoding="utf-8") as configuration_file:
                json_contents = json.load(configuration_file)
        self._path: Path = path
        self.is_local: bool = path.name == ".pyre_configuration.local"
        self.root: str = str(path.parent)
        self.original_contents: Dict[str, Any] = json_contents

        # Configuration fields. Each attribute mirrors the JSON key of the same
        # name and is `None` when that key is absent (or null).
        self.strict: Optional[bool] = json_contents.get("strict")
        self.targets: Optional[List[str]] = json_contents.get("targets")
        self.source_directories: Optional[List[str]] = json_contents.get(
            "source_directories"
        )
        self.version: Optional[str] = json_contents.get("version")
        self.pysa_version: Optional[str] = json_contents.get("pysa_version")
        self.use_buck_builder: Optional[bool] = json_contents.get("use_buck_builder")
        self.use_buck_source_database: Optional[bool] = json_contents.get(
            "use_buck_source_database"
        )
        self.ignore_all_errors: Optional[List[str]] = json_contents.get(
            "ignore_all_errors"
        )
        self.exclude: Optional[List[str]] = json_contents.get("exclude")
        self.use_buck2: Optional[bool] = json_contents.get("use_buck2")

    def get_contents(self) -> Dict[str, Any]:
        """Return the JSON contents to persist for this configuration.

        Starts from a shallow copy of ``original_contents`` and overlays the
        managed fields listed below. A field that is ``None`` is removed; any
        other value replaces the stored one. Keys that are not managed here
        pass through untouched.

        Assumption: The field names in this class match the key names in
        the configuration.
        """
        # Copy so that serializing does not mutate `original_contents`, which
        # may be the very dict the caller passed to the constructor.
        contents: Dict[str, Any] = dict(self.original_contents)
        for key in (
            "targets",
            "source_directories",
            "version",
            "pysa_version",
            "strict",
            "use_buck_builder",
            "use_buck_source_database",
            "use_buck2",
        ):
            value = getattr(self, key)
            if value is not None:
                contents[key] = value
            else:
                contents.pop(key, None)
        return contents

    @staticmethod
    def find_parent_file(
        filename: str, directory: Optional[Path] = None
    ) -> Optional[Path]:
        """Return the nearest ``<dir>/filename`` that is a regular file.

        The search starts at ``directory`` (the current working directory by
        default) and walks up through every ancestor, including the filesystem
        root. The returned path is absolute. Returns ``None`` when no ancestor
        contains ``filename``.
        """
        # Make the start absolute first: for a relative path the walk would
        # stop at `Path(".")` (whose parent is itself) and never look above
        # the working directory.
        current = (directory or Path.cwd()).absolute()
        while True:
            candidate = current / filename
            if candidate.is_file():
                return candidate
            parent = current.parent
            if parent == current:
                # Reached the filesystem root without a match.
                return None
            current = parent

    @staticmethod
    def find_project_configuration(directory: Optional[Path] = None) -> Path:
        """Return the nearest enclosing ``.pyre_configuration``.

        Raises:
            UserError: If no ancestor of ``directory`` contains one.
        """
        path = Configuration.find_parent_file(".pyre_configuration", directory)
        if path is None:
            raise UserError("No root with a `.pyre_configuration` found.")
        return path

    @staticmethod
    def find_local_configuration(directory: Optional[Path] = None) -> Optional[Path]:
        """Return the nearest enclosing ``.pyre_configuration.local``, if any."""
        return Configuration.find_parent_file(".pyre_configuration.local", directory)

    @staticmethod
    def gather_local_configuration_paths(directory: str) -> Sequence[Path]:
        """List local configuration files under ``directory``.

        Matching is delegated to the project filesystem abstraction.
        """
        return [
            Path(path)
            for path in get_filesystem().list(
                directory, patterns=[r"**\.pyre_configuration.local"]
            )
        ]

    @staticmethod
    def gather_local_configurations() -> List["Configuration"]:
        """Load every local configuration below the current directory.

        Files that are not valid JSON are logged and skipped.
        """
        LOG.info("Finding configurations...")
        configuration_paths = Configuration.gather_local_configuration_paths(".")
        if not configuration_paths:
            LOG.info("No projects with local configurations found.")
            return []
        configurations: List[Configuration] = []
        for configuration_path in configuration_paths:
            try:
                configurations.append(Configuration(configuration_path))
            except json.decoder.JSONDecodeError:
                LOG.error(
                    "Configuration at `%s` is invalid, skipping.",
                    configuration_path,
                )
        LOG.info(
            "Found %d local configuration%s.",
            len(configurations),
            "s" if len(configurations) != 1 else "",
        )
        return configurations

    def get_path(self) -> Path:
        """Return the path of the configuration file itself."""
        return self._path

    def get_source_paths(self) -> Generator[Path, None, None]:
        """Yield every ``*.py`` file under the configuration root.

        This is an approximation: it ignores ``source_directories``,
        ``targets`` and ``exclude``.
        """
        return Path(self.root).glob("**/*.py")

    def get_exclude_as_patterns(self) -> Set[re.Pattern[str]]:
        """Compile each ``exclude`` entry as a regular expression."""
        if self.exclude is not None:
            return {re.compile(pattern) for pattern in self.exclude}
        else:
            return set()

    def _relative_path_from_prefix(self, prefix: str) -> Path:
        """Resolve a configured path prefix to a filesystem path.

        A prefix without a leading ``//`` is relative to this configuration's
        root. A ``//`` prefix is relative to the enclosing project root for a
        local configuration, and to this configuration's own root otherwise.
        """
        root = Path(self.root)
        if not prefix.startswith("//"):
            return root / prefix

        # strip preceding '//'
        stripped_prefix = prefix[2:]

        if self.is_local:
            global_root = self.find_project_configuration(root).parent
            return global_root / stripped_prefix
        else:
            return root / stripped_prefix

    def get_resolved_ignore_path_prefixes(self) -> Set[Path]:
        """Return the ``ignore_all_errors`` entries as resolved paths."""
        if self.ignore_all_errors is not None:
            return {
                self._relative_path_from_prefix(prefix).resolve()
                for prefix in self.ignore_all_errors
            }
        else:
            return set()

    def get_directory(self) -> Path:
        """Return the directory containing the configuration file."""
        return self._path.parent

    def write(self) -> None:
        """Write ``get_contents()`` back to the configuration file.

        The output is sorted, two-space-indented JSON with a trailing newline.
        """
        with open(self._path, "w", encoding="utf-8") as configuration_file:
            json.dump(self.get_contents(), configuration_file, sort_keys=True, indent=2)
            configuration_file.write("\n")

    def remove_version(self) -> None:
        """Clear ``version``; logs and does nothing if no version is set."""
        if not self.version:
            LOG.info("Version not found in configuration.")
            return
        self.version = None

    def delete(self) -> None:
        """Delete the configuration file from disk."""
        self._path.unlink()

    def set_version(self, version: str) -> None:
        """Set the ``version`` field."""
        self.version = version

    def set_pysa_version(self, pysa_version: str) -> None:
        """Set the ``pysa_version`` field."""
        self.pysa_version = pysa_version

    def enable_source_database_buck_builder(self) -> None:
        """Turn on both ``use_buck_builder`` and ``use_buck_source_database``."""
        self.use_buck_builder = True
        self.use_buck_source_database = True

    def set_use_buck1_if_possible(self) -> None:
        """Set ``use_buck2`` to False unless it is already configured."""
        if self.use_buck2 is None:
            self.use_buck2 = False

    def add_strict(self) -> None:
        """Set ``strict`` to True; logs and does nothing if already strict."""
        if self.strict:
            LOG.info("Configuration is already strict.")
            return
        self.strict = True

    # Pyre is now strict by default, so in most cases we want to
    # remove the strict flag from the config file.
    def use_strict_default(self) -> None:
        """Clear ``strict`` so the key is dropped on the next write."""
        self.strict = None

    def add_targets(self, targets: List[str]) -> None:
        """Append ``targets`` to the configured targets.

        A new list is always built, so neither the caller's ``targets`` list
        nor the previously stored list is modified in place.
        """
        self.targets = [*(self.targets or []), *targets]

    def has_single_wildcard_target(self) -> bool:
        """Return True if exactly one target is configured and ends in ``...``."""
        return (
            self.targets is not None
            and len(self.targets) == 1
            and any(t.endswith("...") for t in self.targets)
        )

    def deduplicate_targets(self) -> None:
        """Drop redundant targets while keeping the original ordering.

        Recursive patterns (ending in ``/...``) are considered first, in
        sorted order, then all other targets in sorted order. Patterns ending
        in ``/...`` or ``:`` are expanded with ``buck2 query``. Such a pattern
        is dropped when every target it expands to is already covered by an
        earlier pattern. It is kept when the query fails or returns nothing.
        Any other target is dropped if it is already covered. The surviving
        targets keep the order of their first appearance in ``self.targets``.
        """
        targets = self.targets
        if not targets:
            return
        glob_targets = sorted({t for t in targets if t.endswith("/...")})
        non_glob_targets = sorted({t for t in targets if not t.endswith("/...")})
        covered: Set[str] = set()
        kept: List[str] = []
        for target in glob_targets + non_glob_targets:
            if target.endswith(("/...", ":")):
                try:
                    output = subprocess.check_output(
                        ["buck2", "query", target]
                    ).decode()
                except (FileNotFoundError, subprocess.CalledProcessError) as error:
                    LOG.warning("Failed to query target: %s\n%s", target, str(error))
                    kept.append(target)
                    continue
                # Ignore blank lines: an empty query result must not turn into
                # an "" entry that makes later empty patterns look covered.
                expanded = [
                    line.strip() for line in output.splitlines() if line.strip()
                ]
                if not expanded or not covered.issuperset(expanded):
                    covered.update(expanded)
                    kept.append(target)
            elif target not in covered:
                covered.add(target)
                kept.append(target)
        kept_set = set(kept)
        # `dict.fromkeys` de-duplicates while preserving first-appearance order.
        self.targets = [t for t in dict.fromkeys(targets) if t in kept_set]

    def run_pyre(
        self,
        arguments: List[str],
        description: str,
        should_clean: bool,
        command_input: Optional[str],
        stderr_flag: "subprocess._FILE" = subprocess.PIPE,
    ) -> Optional["subprocess.CompletedProcess[str]"]:
        """Run the ``pyre`` executable and capture its output as text.

        Args:
            arguments: Arguments passed to ``pyre``.
            description: Message logged before pyre starts.
            should_clean: Run ``buck clean`` (200 second timeout) first.
            command_input: Text sent to pyre's stdin, if any.
            stderr_flag: Destination for pyre's stderr (captured by default).

        Returns:
            The completed process, or ``None`` if ``buck clean`` timed out.
            Pyre's own exit status is not checked (no ``check=True``), so a
            non-zero status is returned to the caller rather than raised.
        """
        if should_clean:
            # If building targets, run clean or space may run out on device!
            LOG.info("Running `buck clean`...")
            try:
                return_code = subprocess.call(["buck", "clean"], timeout=200)
            except subprocess.TimeoutExpired:
                LOG.warning("Buck timed out. Try running `buck kill` before retrying.")
                return None
            # `subprocess.call` reports failure through its return code and
            # never raises `CalledProcessError`, so check the code explicitly.
            # A failed clean is best-effort: warn and still run pyre.
            if return_code != 0:
                LOG.warning("Error calling `buck clean`: exit status %s", return_code)
        try:
            LOG.info("%s", description)
            return subprocess.run(
                ["pyre", *arguments],
                stdout=subprocess.PIPE,
                stderr=stderr_flag,
                text=True,
                input=command_input,
            )
        except subprocess.CalledProcessError as error:
            LOG.warning("Error calling pyre: %s", str(error))
            return None

    def get_nested_configuration_paths(self) -> Set[Path]:
        """Return local configurations under the root, excluding this file."""
        filesystem = get_filesystem()
        nested_configurations = filesystem.list(
            root=self.root, patterns=["**/.pyre_configuration.local"]
        )
        filtered_nested_configurations = {
            path
            for nested_configuration in nested_configurations
            if (path := Path(self.root, nested_configuration)) != self.get_path()
        }
        LOG.info("Found %d nested configurations", len(filtered_nested_configurations))
        LOG.debug("Nested configurations found: %s", filtered_nested_configurations)
        return filtered_nested_configurations

    def get_errors(
        self,
        only_fix_error_code: Optional[int] = None,
        should_clean: bool = True,
        command_input: Optional[str] = None,
        strict: bool = False,
    ) -> Errors:
        """Run ``pyre check`` for this configuration and parse its JSON output.

        Args:
            only_fix_error_code: Forwarded to ``Errors.from_json``.
            should_clean: Allow a ``buck clean`` first; it only happens when
                ``targets`` is configured.
            command_input: Text sent to pyre's stdin, if any.
            strict: Pass ``--strict`` to pyre.

        Returns:
            The parsed errors, or ``Errors.empty()`` if pyre did not run or
            produced no stdout.

        Raises:
            UserError: If the output cannot be parsed (logged, then re-raised).
        """
        local_root_arguments = (
            ["--local-configuration", self.root] if self.is_local else []
        )
        strict_arguments = ["--strict"] if strict else []
        arguments = [*strict_arguments, *local_root_arguments, "--output=json", "check"]
        pyre_output = self.run_pyre(
            arguments=arguments,
            description=f"Checking `{self.root}`...",
            should_clean=self.targets is not None and should_clean,
            command_input=command_input,
        )
        if pyre_output is None or pyre_output.stdout is None:
            return Errors.empty()

        stdout = pyre_output.stdout.strip()

        try:
            errors = Errors.from_json(stdout, only_fix_error_code)
        except UserError:
            LOG.info("Error when parsing Pyre error output.")
            LOG.info("Pyre stdout: %s\nPyre stderr: %s", stdout, pyre_output.stderr)
            raise

        LOG.info("Found %d error%s.", len(errors), "s" if len(errors) != 1 else "")
        return errors
